###############################################################################
#                                                                             #
#           Skigin, Natán & Aníbal Pérez-Liñán                                #
#      "Preemptive Multipartism and Democratic Transitions"                   #
#                                                                             #
###############################################################################


rm(list=ls(all=TRUE))
library(caret); library(glmnet); library(randomForest); library(readstata13); library(rpart); library(mice); library(tidyverse); library(dplyr); library(vip); library(pdp); library(grid); library(gridExtra)


# Uncomment the following line and set your working directory
#setwd("")
# Location of the replication data set, relative to the working directory
path <- "base_Skigin & Pérez-Liñán_PSRM_final.dta"

# Read the Stata file and keep only the variables used in this script:
# outcomes, covariates, the regime-age polynomials for each coding scheme,
# and the a_1 indicator used for subsetting below
dat1 <- read.dta13(path) %>%
  select(
    end1_gwf, end1_cgv, end1_sv, end2_gwf, end2_cgv, end2_sv,
    frag, unrest, region, e_miinterc, e_miinteco, pgdp, gpgdp,
    oilmin, popm, al_ethnic, al_religion,
    age_gwf, age2_gwf, age3_gwf,
    legis11, personal, enph, a_1,
    age_cgv, age2_cgv, age3_cgv,
    age_sv, age2_sv, age3_sv,
    gwf_topdown, gwf_bottomup,
    gwf_democratization, cgv_democratization, sv_democratization
  )

# Fix the RNG state so the imputation and the forests below are reproducible
set.seed(2222)

# Fill in missing values for the machine-learning models: run mice with its
# default settings (console output silenced) and extract a completed data set
mm <- mice(dat1, printFlag = FALSE)
dat1 <- mice::complete(mm)

# Subset of dictators (rows with a_1 equal to 1)
aut <- subset(dat1, a_1 == 1)


#######################################
#               Figure 3              #               
#######################################

### Random Forests
## GWF
# The outcome must be a factor for randomForest() to fit a classification
# forest rather than a regression forest
aut$gwf_democratization <- as.factor(aut$gwf_democratization)

# Run RF
# Only arguments that randomForest() itself understands are passed. The
# caret::train() options that used to be supplied here (method, tuneLength,
# trControl) were silently absorbed by `...` and had no effect, so no
# cross-validated tuning ever took place; dropping them leaves the fitted
# forest unchanged and stops the call from suggesting otherwise.
rf.out <- randomForest(
  gwf_democratization ~ frag + unrest + region +
    e_miinterc + e_miinteco + pgdp + gpgdp + oilmin +
    popm + al_ethnic + al_religion + age_gwf + age2_gwf +
    age3_gwf + personal + legis11 + enph,
  data = aut,
  ntree = 1500,
  importance = TRUE
)

# Partial dependence plot
# Predicted probability of class "1" as a function of enph, evaluated on a
# fixed grid from 1 to 5 in steps of 0.1
p1.gwf <- partial(rf.out, pred.var = "enph",
                  pred.grid = data.frame(enph = seq(1, 5, by = .1)),
                  which.class = "1", prob = TRUE, rug = TRUE)
p.gwf <- autoplot(p1.gwf, contour = TRUE, center = TRUE) +
  ylab("P(regime change)") +
  ylim(0, .5) +
  geom_line(size = 1.5) + theme_bw() + xlab("ENP (House)") +
  theme(axis.text = element_text(size = 14),
        axis.title = element_text(size = 17),
        plot.title = element_text(hjust = 0.5, size = 24)) +
  ggtitle("GWF")


## CGV
aut$cgv_democratization <- as.factor(aut$cgv_democratization) # CGV's DV as factor

# Only arguments that randomForest() itself understands are passed. The
# caret::train() options that used to be supplied here (method, tuneLength,
# trControl) were silently absorbed by `...` and had no effect, so no
# cross-validated tuning ever took place; dropping them leaves the fitted
# forest unchanged.
# NOTE(review): the regime-age polynomial below uses the GWF-coded variables
# (age_gwf, age2_gwf, age3_gwf) even though age_cgv, age2_cgv and age3_cgv are
# selected into the data and labelled in names.plots -- confirm whether the
# CGV-specific age terms were intended. Left as-is so results do not change.
rf.out.cgv <- randomForest(
  cgv_democratization ~ frag + unrest + region +
    e_miinterc + e_miinteco + pgdp + gpgdp + oilmin +
    popm + al_ethnic + al_religion + age_gwf +
    age2_gwf + age3_gwf + personal + legis11 + enph,
  data = aut,
  ntree = 1500,
  importance = TRUE
)

# Partial dependence plot
# Predicted probability of class "1" as a function of enph, evaluated on a
# fixed grid from 1 to 5 in steps of 0.1
p1.cgv <- partial(rf.out.cgv, pred.var = "enph",
                  pred.grid = data.frame(enph = seq(1, 5, by = .1)),
                  which.class = "1", prob = TRUE, rug = TRUE)
p.cgv <- autoplot(p1.cgv, contour = TRUE, center = TRUE) +
  ylab("P(regime change)") +
  ylim(0, .5) +
  geom_line(size = 1.5) + theme_bw() + xlab("ENP (House)") +
  theme(axis.text = element_text(size = 14),
        axis.title = element_text(size = 17),
        plot.title = element_text(hjust = 0.5, size = 24)) +
  ggtitle("CGV")

## Svolik
aut$sv_democratization <- factor(aut$sv_democratization, levels = c(0, 1)) # Svolik's DV as factor

# Only arguments that randomForest() itself understands are passed. The
# caret::train() options that used to be supplied here (method, tuneLength,
# trControl) were silently absorbed by `...` and had no effect, so no
# cross-validated tuning ever took place; dropping them leaves the fitted
# forest unchanged.
# NOTE(review): the regime-age polynomial below uses the GWF-coded variables
# (age_gwf, age2_gwf, age3_gwf) even though age_sv, age2_sv and age3_sv are
# selected into the data and labelled in names.plots -- confirm whether the
# Svolik-specific age terms were intended. Left as-is so results do not change.
rf.out.sv <- randomForest(
  sv_democratization ~ frag + unrest + region +
    e_miinterc + e_miinteco + pgdp + gpgdp + oilmin +
    popm + al_ethnic + al_religion + age_gwf +
    age2_gwf + age3_gwf + personal + legis11 + enph,
  data = aut,
  ntree = 1500,
  importance = TRUE
)

# Partial dependence plot
# Predicted probability of class "1" as a function of enph, evaluated on a
# fixed grid from 1 to 5 in steps of 0.1
p1.sv <- partial(rf.out.sv, pred.var = "enph",
                 pred.grid = data.frame(enph = seq(1, 5, by = .1)),
                 which.class = "1", prob = TRUE, rug = TRUE)
p.sv <- autoplot(p1.sv, contour = TRUE, center = TRUE) +
  ylab("P(regime change)") +
  ylim(0, .5) +
  geom_line(size = 1.5) + theme_bw() + xlab("ENP (House)") +
  theme(axis.text = element_text(size = 14),
        axis.title = element_text(size = 17),
        plot.title = element_text(hjust = 0.5, size = 24)) +
  ggtitle("Svolik")

# Write the three partial dependence plots (GWF, CGV, Svolik) side by side
# into a single PNG file
png("RF_CGV,GWF, SV.png", width = 465, height = 225, units = "mm", res = 600)
grid.arrange(
  grobs = lapply(list(p.gwf, p.cgv, p.sv), arrangeGrob),
  ncol = 3,
  # Empty heading; a descriptive alternative would be
  # textGrob("Nonlinear Effects with Random Forests",
  #          gp = gpar(fontsize = 20, font = 3))
  top = textGrob("")
)
dev.off()



#######################################
#               Figure A4             #               
#######################################
# Combine plots from GWF, CGV & Svolik's VarImpPlots
# Variable Importance Plots
# Lookup table from variable names in the data to the axis labels shown in
# the variable importance plots. The three regime-age polynomials (Svolik,
# CGV, GWF codings) share the same display labels.
names.plots <- c(
  frag        = "Fragmented opposition",
  unrest      = "Unrest",
  region      = "Regional democracy",
  enph        = "ENP (House)",
  e_miinterc  = "Internal conflict",
  e_miinteco  = "International conflict",
  pgdp        = "Per capita GDP",
  gpgdp       = "Economic growth",
  oilmin      = "Mineral exports",
  popm        = "Population (millions)",
  personal    = "Personal",
  al_ethnic   = "Ethnic Fractionalization",
  al_religion = "Religious Fractionalization",
  age_sv      = "Age of the regime",
  age2_sv     = "Age^2",
  age3_sv     = "Age^3",
  age_cgv     = "Age of the regime",
  age2_cgv    = "Age^2",
  age3_cgv    = "Age^3",
  age_gwf     = "Age of the regime",
  age2_gwf    = "Age^2",
  age3_gwf    = "Age^3",
  legis11     = "Legislature"
)


# Turn the importance() matrix of each forest into a data frame and copy the
# row names (the predictor names) into a `feature` column for plotting

# GWF
feat_imp_df_gwf <- data.frame(importance(rf.out))
feat_imp_df_gwf <- mutate(feat_imp_df_gwf,
                          feature = row.names(feat_imp_df_gwf))

# CGV
feat_imp_df_cgv <- data.frame(importance(rf.out.cgv))
feat_imp_df_cgv <- mutate(feat_imp_df_cgv,
                          feature = row.names(feat_imp_df_cgv))

# Svolik
feat_imp_df_sv <- data.frame(importance(rf.out.sv))
feat_imp_df_sv <- mutate(feat_imp_df_sv,
                         feature = row.names(feat_imp_df_sv))

# Build one horizontal bar chart of variable importance.
#
# The three panels of Figure A4 differ only in their data, title and x-axis
# label, so the shared ggplot specification lives in this helper instead of
# being copied three times.
#
# Arguments:
#   imp_df       data frame with columns `feature` and `MeanDecreaseGini`
#   panel_title  title printed above the panel
#   x_label      label for the variable axis ("" leaves it blank)
#   labels       named character vector mapping feature names to display labels
# Returns: a ggplot object with bars ordered by MeanDecreaseGini.
plot_var_importance <- function(imp_df, panel_title, x_label = "",
                                labels = names.plots) {
  ggplot(imp_df, aes(x = reorder(feature, MeanDecreaseGini),
                     y = MeanDecreaseGini)) +
    geom_bar(stat = "identity") +
    coord_flip() + # features on the vertical axis, bars horizontal
    theme_classic() +
    labs(
      x     = x_label,
      y     = "Mean Decrease in Gini Score",
      title = panel_title
    ) +
    scale_x_discrete(labels = labels) +
    theme(plot.title = element_text(hjust = 0.5, size = 20, face = "bold"),
          title = element_text(size = 15, face = "bold"),
          axis.title = element_text(size = 17),
          axis.text = element_text(size = 15))
}

# Only the left-most panel carries the "Variable" axis label
VI.gwf <- plot_var_importance(feat_imp_df_gwf, "A4.1 GWF", x_label = "Variable")
VI.cgv <- plot_var_importance(feat_imp_df_cgv, "A4.2 CGV")
VI.sv <- plot_var_importance(feat_imp_df_sv, "A4.3 Svolik")

# Write the three variable importance plots (GWF, CGV, Svolik) side by side
# into a single PNG file
png("VarImpPlot_GWF_CGV_SV.png", width = 465, height = 225, units = "mm",
    res = 600)
grid.arrange(
  grobs = lapply(list(VI.gwf, VI.cgv, VI.sv), arrangeGrob),
  ncol = 3,
  # Empty heading; a descriptive alternative would be
  # textGrob("Variable Importance for Random Forests")
  top = textGrob("")
)
dev.off()


#######################################
#               Figure A5             #               
#######################################

### ROC-AUC
library(ROCR)

# For each forest: take the class probabilities returned by predict() without
# new data, pair the second column (the second factor level) with the observed
# labels in a ROCR prediction object, then derive the AUC and the ROC curve.
# NOTE(review): the labels used here are end1_gwf / end1_cgv / end1_sv, whereas
# the forests were trained on gwf_/cgv_/sv_democratization -- confirm that
# these code the same outcome.

## For GWF
pred_gwf <- predict(rf.out, type = "prob")
perf_gwf <- prediction(pred_gwf[, 2], aut$end1_gwf)
# 1. Area under curve
auc_gwf <- performance(perf_gwf, "auc")
# 2. True positive rate against false positive rate (ROC curve)
pred3_gwf <- performance(perf_gwf, "tpr", "fpr")

## For CGV
pred_cgv <- predict(rf.out.cgv, type = "prob")
perf_cgv <- prediction(pred_cgv[, 2], aut$end1_cgv)
# 1. Area under curve
auc_cgv <- performance(perf_cgv, "auc")
# 2. True positive rate against false positive rate (ROC curve)
pred3_cgv <- performance(perf_cgv, "tpr", "fpr")

## For Svolik
pred_sv <- predict(rf.out.sv, type = "prob")
perf_sv <- prediction(pred_sv[, 2], aut$end1_sv)
# 1. Area under curve
auc_sv <- performance(perf_sv, "auc")
# 2. True positive rate against false positive rate (ROC curve)
pred3_sv <- performance(perf_sv, "tpr", "fpr")

# Combine ROC plots
# Draw the three ROC curves side by side in one PNG. Base-graphics calls are
# issued as separate statements: plot() opens each panel and title()/abline()
# then annotate it. (They were previously chained with `+`, which is ggplot2
# syntax; that only worked because each call returns NULL, and it left a stray
# numeric(0) value at top level.)
# NOTE(review): res = 1600 here versus 600 for the other figures; at
# 465 x 225 mm this is roughly a 29,000 x 14,000 pixel bitmap -- confirm that
# this resolution is intended.
png("ROC_GWF_CGV_SV.png", width = 465, height = 225, units = "mm", res = 1600)
par(mfrow = c(1, 3), mar = c(5, 5, 4, 2), cex.axis = 2)

# GWF: the only panel that carries the y-axis label
plot(pred3_gwf, main = "GWF", col = 2, lwd = 2, xlab = "", ylab = "",
     cex.main = 3.5)
title(ylab = expression(italic("True positive rate")), line = 2.1, cex.lab = 3)
title(xlab = expression(italic("False positive rate")), line = 3.5,
      cex.lab = 2.5)
abline(a = 0, b = 1, lwd = 2, lty = 2, col = "gray") # 45-degree reference line

# CGV
plot(pred3_cgv, main = "CGV", col = 2, lwd = 2, xlab = "", ylab = "",
     cex.main = 3.5)
title(xlab = expression(italic("False positive rate")), line = 3.5,
      cex.lab = 2.5)
abline(a = 0, b = 1, lwd = 2, lty = 2, col = "gray")

# Svolik
plot(pred3_sv, main = "Svolik", col = 2, lwd = 2, xlab = "", ylab = "",
     cex.main = 3.5)
title(xlab = expression(italic("False positive rate")), line = 3.5,
      cex.lab = 2.5)
abline(a = 0, b = 1, lwd = 2, lty = 2, col = "gray")

dev.off()
